#pragma once

#include <set>

#include "megbrain/tensor.h"

namespace mgb {
namespace opr {

/*!
 * GroupInfo: stream and shape information from all ranks of a group
 */
class GroupInfo {
public:
    struct OprInfo {
        uint64_t comp_node_hash;
        bool is_root;
        int rank;
    };

    void add_opr(
            const std::string& key, size_t nr_expected_devices, bool is_root, int rank,
            uint64_t comp_node_hash);

    void set_output_shape(const std::string& key, const TensorShape& shape);

    TensorShape get_output_shape(const std::string& key);

    void clear();

    const std::vector<OprInfo>& opr_infos() const { return m_opr_infos; }

    int get_root_rank() const { return m_root_rank; }
    int get_rank(uint64_t hash) const { return m_rank_map.at(hash); }
    uint64_t get_group_hash() const { return m_hash; }

private:
    void sort_opr_infos();
    void gen_infos_from_opr_infos();

    std::vector<OprInfo> m_opr_infos;
    std::unordered_map<uint64_t, int> m_rank_map;
    uint64_t m_hash;
    uint32_t m_nr_registered_devs;
    uint32_t m_nr_expected_devs;
    Maybe<TensorShape> m_output_shape;

    uint32_t m_count = 0;
    int m_root_rank = -1;
    std::mutex m_group_mtx;
    std::condition_variable m_register_cv;
    std::condition_variable m_clear_cv;

    std::mutex m_output_shape_mtx;
    std::condition_variable m_output_shape_cv;
};

/*!
 * GroupManager: build groups and exchange meta information
 */
class GroupManager {
public:
    ~GroupManager() = default;

    struct RegisterInfo {
        uint64_t hash;
        int rank, root_rank;
    };

    //! register oprs' info to server, return deduplicated hash
    RegisterInfo opr_register(
            const std::string& key, size_t nr_devices, bool is_root, int rank,
            uint64_t comp_node_hash);

    //! broadcast master_ip and port
    void bcast_addr(
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root);

    //! bcast uid
    void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root);

    //! Set output shape of this key
    void set_output_shape(const std::string& key, const TensorShape& shape);

    //! Get output shape of this key, blocks until output shape is set
    TensorShape get_output_shape(const std::string& key);

    //! Block clients until all ranks reach this barrier
    uint32_t group_barrier(uint32_t size, uint32_t rank);

private:
    GroupInfo& get_group(const std::string& key);

    //! key -> group info.
    std::unordered_map<std::string, GroupInfo> m_key2group_info;
    std::mutex m_key2group_info_mtx;

    //! key -> addr
    std::unordered_map<std::string, std::string> m_key2master_ip;
    std::unordered_map<std::string, int> m_key2port;
    std::unordered_map<std::string, uint32_t> m_key2addr_size;
    std::unordered_map<std::string, bool> m_key2addr_flag;
    std::mutex m_key2addr_mtx;
    std::condition_variable m_bcast_cv;

    //! key -> ncclid
    std::unordered_map<std::string, std::string> m_key2nccl_id;
    std::unordered_map<std::string, uint32_t> m_key2nccl_id_size;
    std::unordered_map<std::string, bool> m_key2nccl_id_flag;
    std::mutex m_key2nccl_id_mtx;

    //! barrier
    uint32_t m_barrier_size;
    std::set<uint32_t> m_barrier_set;
    std::mutex m_barrier_mtx;
    std::condition_variable m_barrier_cv;
};

/*!
 * Client interface to interact with GroupManager.
 * All the methods below should be overrided by subclasses
 * Test cases mock the interface to directly interact with GroupManager
 */
class GroupClient {
protected:
    virtual ~GroupClient() = default;

public:
    virtual const std::string& get_addr() const = 0;

    virtual GroupManager::RegisterInfo opr_register(
            const std::string& key, size_t nr_devices, bool is_root, int rank,
            uint64_t comp_node_hash) = 0;

    virtual void bcast_addr(
            std::string& master_ip, int& port, const std::string& key, uint32_t size,
            uint32_t rank, uint32_t root) = 0;

    virtual void bcast_nccluniqueid(
            const std::string& key, std::string& id, uint32_t size, uint32_t rank,
            uint32_t root) = 0;

    virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0;

    virtual TensorShape get_output_shape(const std::string& key) = 0;

    virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0;
};

/*!
 * Cache RegisterInfo returned from GroupManager. This feature is only enabled
 * in imperative runtime mode, so that multi-machine operators do not have to
 * call opr_register repeatedly in each iter
 */
namespace RegInfoCache {

static std::mutex mtx;
static std::unordered_map<std::string, GroupManager::RegisterInfo> key2info;

void set_info(const std::string& key, const GroupManager::RegisterInfo& info);
bool has_info(const std::string& key);
GroupManager::RegisterInfo get_info(const std::string& key);

}  // namespace RegInfoCache

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
